{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "0510.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/0510.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_GZdOXZ_jlTJ",
        "colab_type": "text"
      },
      "source": [
        "## 10.12 机器翻译"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UOq-eM-7C5g9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Use the %pip magic so the packages are installed into the running kernel's environment.\n",
        "%pip install mxnet d2lzh"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Gh_171DXsEwa",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import collections \n",
        "import io \n",
        "import math \n",
        "from mxnet import autograd, gluon, init, nd \n",
        "from mxnet.contrib import text \n",
        "from mxnet.gluon import data as gdata, loss as gloss, nn, rnn \n",
        "\n",
        "# Special tokens: PAD fills sequences up to a fixed length, BOS is the first decoder input, EOS marks the end of a sequence.\n",
        "PAD, BOS, EOS = '<pad>', '<bos>', '<eos>'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8Nk_tEc7u1c2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def process_one_seq(seq_tokens, all_tokens, all_seqs, max_seq_len):\n",
        "    \"\"\"Record one tokenized sentence and store its padded form.\n",
        "\n",
        "    The raw tokens are added to ``all_tokens``. ``seq_tokens`` is then grown\n",
        "    in place with EOS followed by ``max_seq_len - len(seq_tokens) - 1`` PAD\n",
        "    entries (none when that count is not positive), and the same list object\n",
        "    is appended to ``all_seqs``.\n",
        "    \"\"\"\n",
        "    all_tokens.extend(seq_tokens)\n",
        "    padding = [PAD] * (max_seq_len - len(seq_tokens) - 1)\n",
        "    seq_tokens += [EOS] + padding\n",
        "    all_seqs.append(seq_tokens)\n",
        "\n",
        "def build_data(all_tokens, all_seqs):\n",
        "    \"\"\"Build a vocabulary from ``all_tokens`` and index every sequence with it.\n",
        "\n",
        "    PAD, BOS and EOS are registered as reserved tokens. Returns the\n",
        "    vocabulary together with an NDArray built from the per-sequence index lists.\n",
        "    \"\"\"\n",
        "    token_counts = collections.Counter(all_tokens)\n",
        "    vocab = text.vocab.Vocabulary(token_counts, reserved_tokens=[PAD, BOS, EOS])\n",
        "    indexed_seqs = []\n",
        "    for seq in all_seqs:\n",
        "        indexed_seqs.append(vocab.to_indices(seq))\n",
        "    return vocab, nd.array(indexed_seqs)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rEAIbikvwwvk",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 151
        },
        "outputId": "06c15ba0-2b21-458f-ecc8-74e595d52076"
      },
      "source": [
        "# Shallow-clone from the canonical host, and skip the clone when the checkout already exists so the cell can be re-run.\n",
        "!test -d d2l-zh || git clone --depth 1 https://github.com/d2l-ai/d2l-zh.git"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'd2l-zh'...\n",
            "warning: redirecting to https://github.com/d2l-ai/d2l-zh.git/\n",
            "remote: Enumerating objects: 17, done.\u001b[K\n",
            "remote: Counting objects: 100% (17/17), done.\u001b[K\n",
            "remote: Compressing objects: 100% (13/13), done.\u001b[K\n",
            "remote: Total 15702 (delta 9), reused 8 (delta 4), pack-reused 15685\u001b[K\n",
            "Receiving objects: 100% (15702/15702), 159.56 MiB | 32.84 MiB/s, done.\n",
            "Resolving deltas: 100% (11132/11132), done.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lh1UojwNxA-g",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# -p makes the cell safe to re-run when the directory already exists.\n",
        "!mkdir -p ../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K2RCEmDpxEtQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Copy the small fr-en corpus from the cloned repository to the path that read_data opens.\n",
        "!cp ./d2l-zh/data/fr-en-small.txt ../data/"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "V5uvjkfDxRQX",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "8835f9ed-12b8-4b00-bde7-39654bcfef01"
      },
      "source": [
        "# List the data directory to confirm the corpus file was copied.\n",
        "!ls ../data/"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "fr-en-small.txt\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qKCKqv_cxUZg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def read_data(max_seq_len, path='../data/fr-en-small.txt'):\n",
        "    \"\"\"Load the parallel corpus and turn it into vocabularies and a dataset.\n",
        "\n",
        "    Each non-blank line holds an input sentence and an output sentence\n",
        "    separated by a tab, with tokens separated by single spaces. A pair is\n",
        "    skipped when either side has more than ``max_seq_len - 1`` tokens, which\n",
        "    leaves room for the EOS token in every kept sequence.\n",
        "\n",
        "    Args:\n",
        "        max_seq_len: Length of every stored sequence, EOS and padding included.\n",
        "        path: Location of the tab-separated corpus file.\n",
        "\n",
        "    Returns:\n",
        "        ``(in_vocab, out_vocab, dataset)``: the two vocabularies and an\n",
        "        ``ArrayDataset`` pairing the indexed input and output sequences.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If a non-blank line does not contain exactly one tab.\n",
        "    \"\"\"\n",
        "    in_tokens, out_tokens, in_seqs, out_seqs = [], [], [], []\n",
        "    # Decode explicitly instead of relying on the platform default encoding\n",
        "    # (assumes the corpus file is UTF-8 encoded).\n",
        "    with io.open(path, encoding='utf-8') as f:\n",
        "        for line in f:\n",
        "            line = line.rstrip()\n",
        "            if not line:\n",
        "                continue  # tolerate blank lines instead of failing the unpack below\n",
        "            in_seq, out_seq = line.split('\\t')\n",
        "            in_seq_tokens, out_seq_tokens = in_seq.split(' '), out_seq.split(' ')\n",
        "            if max(len(in_seq_tokens), len(out_seq_tokens)) > max_seq_len - 1:\n",
        "                continue  # no room left for EOS\n",
        "            process_one_seq(in_seq_tokens, in_tokens, in_seqs, max_seq_len)\n",
        "            process_one_seq(out_seq_tokens, out_tokens, out_seqs, max_seq_len)\n",
        "    in_vocab, in_data = build_data(in_tokens, in_seqs)\n",
        "    out_vocab, out_data = build_data(out_tokens, out_seqs)\n",
        "    return in_vocab, out_vocab, gdata.ArrayDataset(in_data, out_data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WsBVrOgt0EZI",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "67e1a3f7-2ac1-4dfd-bdf5-70b8d273330f"
      },
      "source": [
        "# read_data skips any pair in which a side has more than max_seq_len - 1 tokens.\n",
        "max_seq_len = 7 \n",
        "in_vocab, out_vocab, dataset = read_data(max_seq_len)\n",
        "# Show the first (input, output) pair of index sequences.\n",
        "dataset[0]"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(\n",
              " [ 6.  5. 46.  4.  3.  1.  1.]\n",
              " <NDArray 7 @cpu(0)>, \n",
              " [ 9.  5. 28.  4.  3.  1.  1.]\n",
              " <NDArray 7 @cpu(0)>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zkPt29qA0TLg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Encoder(nn.Block):\n",
        "    \"\"\"Embeds token indices and feeds them to a multi-layer GRU.\"\"\"\n",
        "\n",
        "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, drop_prob=0, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
        "        self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=drop_prob)\n",
        "\n",
        "    def forward(self, inputs, state):\n",
        "        \"\"\"Return whatever the GRU returns for ``inputs`` when started from ``state``.\"\"\"\n",
        "        # The first two axes of the embedded input are swapped before the GRU\n",
        "        # (presumably batch-major -> time-major, as in the shape demo in this notebook).\n",
        "        swapped = self.embedding(inputs).swapaxes(0, 1)\n",
        "        return self.rnn(swapped, state)\n",
        "\n",
        "    def begin_state(self, *args, **kwargs):\n",
        "        \"\"\"Delegate creation of the initial state to the underlying GRU.\"\"\"\n",
        "        return self.rnn.begin_state(*args, **kwargs)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DdCXA0yA2qfa",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "eeeb287d-a09b-4818-e317-788f615d961d"
      },
      "source": [
        "# Shape check: an input of shape (4, 7) through a 2-layer encoder with 16 hidden units.\n",
        "# The saved output shows (7, 4, 16) for the output and (2, 4, 16) for the first state array.\n",
        "encoder = Encoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)\n",
        "encoder.initialize()\n",
        "output, state = encoder(nd.zeros((4, 7)), encoder.begin_state(batch_size=4))\n",
        "output.shape, state[0].shape "
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "((7, 4, 16), (2, 4, 16))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JJ0-fSje4gDN",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "2d3c7f81-62e3-4a6b-ff49-dbf7ec0be365"
      },
      "source": [
        "# With flatten=False the Dense layer maps a (3, 5, 7) input to (3, 5, 2), as the saved output shows.\n",
        "dense = nn.Dense(2, flatten=False)\n",
        "dense.initialize()\n",
        "dense(nd.zeros((3, 5, 7))).shape "
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(3, 5, 2)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d78HvoeB5HU2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def attention_model(attention_size):\n",
        "    \"\"\"Create the two-layer network that scores encoder/decoder state pairs.\n",
        "\n",
        "    A bias-free tanh layer with ``attention_size`` units is followed by a\n",
        "    bias-free single-unit layer; both are built with ``flatten=False``.\n",
        "    \"\"\"\n",
        "    net = nn.Sequential()\n",
        "    hidden = nn.Dense(attention_size, activation='tanh', use_bias=False, flatten=False)\n",
        "    scorer = nn.Dense(1, use_bias=False, flatten=False)\n",
        "    net.add(hidden)\n",
        "    net.add(scorer)\n",
        "    return net"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XTJ1romE6R9w",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def attention_forward(model, enc_states, dec_state):\n",
        "    \"\"\"Compute the attention-weighted sum of the encoder states.\n",
        "\n",
        "    ``dec_state`` is repeated along a new leading axis to match the first\n",
        "    axis of ``enc_states``, concatenated with ``enc_states`` on axis 2 and\n",
        "    scored by ``model``. The scores are softmax-normalised over axis 0 and\n",
        "    used to weight ``enc_states`` before summing over that axis.\n",
        "    \"\"\"\n",
        "    num_steps = enc_states.shape[0]\n",
        "    dec_repeated = nd.broadcast_axis(dec_state.expand_dims(0), axis=0, size=num_steps)\n",
        "    scores = model(nd.concat(enc_states, dec_repeated, dim=2))\n",
        "    weights = nd.softmax(scores, axis=0)\n",
        "    weighted_states = weights * enc_states\n",
        "    return weighted_states.sum(axis=0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-Gvs7Zy57gWf",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "6815a2e0-3399-4bf9-cf88-8d8c5badf30d"
      },
      "source": [
        "# Shape check for one attention step; the saved output is (4, 8), i.e. (batch_size, num_hiddens).\n",
        "# Note: this rebinds the globals batch_size and num_hiddens, which the training cell assigns again later.\n",
        "seq_len, batch_size, num_hiddens = 10, 4, 8\n",
        "model = attention_model(10)\n",
        "model.initialize()\n",
        "enc_states = nd.zeros((seq_len, batch_size, num_hiddens))\n",
        "dec_state = nd.zeros((batch_size, num_hiddens))\n",
        "attention_forward(model, enc_states, dec_state).shape "
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(4, 8)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 19
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KD9U5Ywy9PqK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Decoder(nn.Block):\n",
        "    \"\"\"GRU decoder that recomputes an attention context at every step.\"\"\"\n",
        "\n",
        "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, attention_size, drop_prob=0, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
        "        self.attention = attention_model(attention_size)\n",
        "        self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=drop_prob)\n",
        "        self.out = nn.Dense(vocab_size, flatten=False)\n",
        "\n",
        "    def forward(self, cur_input, state, enc_states):\n",
        "        \"\"\"Run one decoding step and return ``(output, new_state)``.\"\"\"\n",
        "        # The last slice of the first state array is what attention is computed against.\n",
        "        query = state[0][-1]\n",
        "        context = attention_forward(self.attention, enc_states, query)\n",
        "        embedded = self.embedding(cur_input)\n",
        "        # Join embedding and context on axis 1, then add a leading axis for the GRU.\n",
        "        rnn_input = nd.concat(embedded, context, dim=1).expand_dims(0)\n",
        "        rnn_output, next_state = self.rnn(rnn_input, state)\n",
        "        scores = self.out(rnn_output).squeeze(axis=0)\n",
        "        return scores, next_state\n",
        "\n",
        "    def begin_state(self, enc_state):\n",
        "        \"\"\"Use the given encoder state, unchanged, as the initial decoder state.\"\"\"\n",
        "        return enc_state"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9ue1l7p1VbH_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def batch_loss(encoder, decoder, X, Y, loss):\n",
        "    \"\"\"Return the masked loss of one mini-batch, averaged over counted tokens.\n",
        "\n",
        "    The decoder starts from BOS and is fed the label of the previous step\n",
        "    (teacher forcing). A per-example mask starts at one and is zeroed after\n",
        "    the step whose label is EOS, so later positions add nothing to the loss\n",
        "    or to the token count. Reads the global ``out_vocab``.\n",
        "    \"\"\"\n",
        "    batch_size = X.shape[0]\n",
        "    enc_outputs, enc_state = encoder(X, encoder.begin_state(batch_size=batch_size))\n",
        "    dec_state = decoder.begin_state(enc_state)\n",
        "    bos_idx = out_vocab.token_to_idx[BOS]\n",
        "    eos_idx = out_vocab.token_to_idx[EOS]\n",
        "    dec_input = nd.array([bos_idx] * batch_size)\n",
        "    mask = nd.ones(shape=(batch_size,))\n",
        "    num_not_pad_tokens = 0\n",
        "    total = nd.array([0])\n",
        "    for target in Y.T:\n",
        "        dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs)\n",
        "        total = total + (mask * loss(dec_output, target)).sum()\n",
        "        dec_input = target\n",
        "        # Count this step before updating the mask so the EOS position itself is included.\n",
        "        num_not_pad_tokens += mask.sum().asscalar()\n",
        "        mask = mask * (target != eos_idx)\n",
        "    return total / num_not_pad_tokens"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rTGAuAKpW8qZ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train(encoder, decoder, dataset, lr, batch_size, num_epochs):\n",
        "    \"\"\"Train the encoder and decoder jointly with Adam.\n",
        "\n",
        "    Both networks are re-initialised with Xavier initialisation, each gets\n",
        "    its own trainer, and the mean batch loss is printed every 10 epochs.\n",
        "    \"\"\"\n",
        "    for net in (encoder, decoder):\n",
        "        net.initialize(init.Xavier(), force_reinit=True)\n",
        "    enc_trainer = gluon.Trainer(encoder.collect_params(), 'adam', {'learning_rate': lr})\n",
        "    dec_trainer = gluon.Trainer(decoder.collect_params(), 'adam', {'learning_rate': lr})\n",
        "    loss = gloss.SoftmaxCrossEntropyLoss()\n",
        "    data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True)\n",
        "    for epoch in range(1, num_epochs + 1):\n",
        "        epoch_loss = 0.0\n",
        "        for X, Y in data_iter:\n",
        "            with autograd.record():\n",
        "                batch_l = batch_loss(encoder, decoder, X, Y, loss)\n",
        "            batch_l.backward()\n",
        "            # batch_loss already returns an average, so step with a batch size of 1.\n",
        "            enc_trainer.step(1)\n",
        "            dec_trainer.step(1)\n",
        "            epoch_loss += batch_l.asscalar()\n",
        "        if epoch % 10 == 0:\n",
        "            print(\"epoch %d, loss %.3f\" % (epoch, epoch_loss / len(data_iter)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JjqaKpopYX2Q",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "b0ab2008-4d36-4bd9-ab96-48a7d09922f8"
      },
      "source": [
        "# Model and optimizer hyperparameters; then build both networks and train them.\n",
        "embed_size, num_hiddens, num_layers = 64, 64, 2\n",
        "attention_size, drop_prob, lr, batch_size, num_epochs = 10, 0.5, 0.01, 2, 50 \n",
        "encoder = Encoder(len(in_vocab), embed_size, num_hiddens, num_layers, drop_prob)\n",
        "decoder = Decoder(len(out_vocab), embed_size, num_hiddens, num_layers, attention_size, drop_prob)\n",
        "train(encoder, decoder, dataset, lr, batch_size, num_epochs)"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 10, loss 0.426\n",
            "epoch 20, loss 0.244\n",
            "epoch 30, loss 0.143\n",
            "epoch 40, loss 0.103\n",
            "epoch 50, loss 0.053\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FTiYNsAhZQ3M",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def translate(encoder, decoder, input_seq, max_seq_len):\n",
        "    \"\"\"Greedily decode a translation of the space-separated ``input_seq``.\n",
        "\n",
        "    The input is given EOS and PAD tokens like the training data, encoded,\n",
        "    and decoded one token at a time by taking the argmax of each step. The\n",
        "    loop ends at EOS or after ``max_seq_len`` steps. Returns the list of\n",
        "    predicted tokens without EOS. Reads the globals ``in_vocab`` and ``out_vocab``.\n",
        "    \"\"\"\n",
        "    src_tokens = input_seq.split(' ')\n",
        "    src_tokens += [EOS] + [PAD] * (max_seq_len - len(src_tokens) - 1)\n",
        "    enc_input = nd.array([in_vocab.to_indices(src_tokens)])\n",
        "    enc_output, enc_state = encoder(enc_input, encoder.begin_state(batch_size=1))\n",
        "    dec_input = nd.array([out_vocab.token_to_idx[BOS]])\n",
        "    dec_state = decoder.begin_state(enc_state)\n",
        "    output_tokens = []\n",
        "    for _ in range(max_seq_len):\n",
        "        dec_output, dec_state = decoder(dec_input, dec_state, enc_output)\n",
        "        best = dec_output.argmax(axis=1)\n",
        "        token = out_vocab.idx_to_token[int(best.asscalar())]\n",
        "        if token == EOS:\n",
        "            break\n",
        "        output_tokens.append(token)\n",
        "        # The predicted token becomes the next decoder input.\n",
        "        dec_input = best\n",
        "    return output_tokens"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q_-wN30ybM6Y",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "c3e6666e-f86b-4ebc-b992-e8028f9a86df"
      },
      "source": [
        "# Translate one sample sentence with the trained model.\n",
        "input_seq = 'ils regardent .'\n",
        "translate(encoder, decoder, input_seq, max_seq_len)"
      ],
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['they', 'are', 'watching', '.']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 32
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KJbDwXd9bZb6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def bleu(pred_tokens, label_tokens, k):\n",
        "    \"\"\"Compute the BLEU score of a predicted token sequence.\n",
        "\n",
        "    The score is the brevity penalty ``exp(min(0, 1 - len_label / len_pred))``\n",
        "    multiplied, for every n from 1 to ``k``, by the clipped n-gram precision\n",
        "    raised to the power ``0.5 ** n``.\n",
        "\n",
        "    Args:\n",
        "        pred_tokens: Predicted tokens.\n",
        "        label_tokens: Reference tokens.\n",
        "        k: Largest n-gram order to include.\n",
        "\n",
        "    Returns:\n",
        "        A float between 0 and 1. The result is 0.0 when the prediction is\n",
        "        empty or has fewer than ``k`` tokens, because it then contains no\n",
        "        n-grams of some required order.\n",
        "    \"\"\"\n",
        "    len_pred, len_label = len(pred_tokens), len(label_tokens)\n",
        "    if len_pred == 0:\n",
        "        # Avoid dividing by zero in the brevity penalty.\n",
        "        return 0.0\n",
        "    score = math.exp(min(0, 1 - len_label / len_pred))\n",
        "    for n in range(1, k + 1):\n",
        "        num_pred_ngrams = len_pred - n + 1\n",
        "        if num_pred_ngrams <= 0:\n",
        "            # The prediction is too short to contain any n-gram of this order.\n",
        "            return 0.0\n",
        "        # Tuples keep token boundaries, so ('ab', 'c') and ('a', 'bc') stay distinct.\n",
        "        label_subs = collections.Counter(\n",
        "            tuple(label_tokens[i: i + n]) for i in range(len_label - n + 1))\n",
        "        num_matches = 0\n",
        "        for i in range(num_pred_ngrams):\n",
        "            ngram = tuple(pred_tokens[i: i + n])\n",
        "            if label_subs[ngram] > 0:\n",
        "                num_matches += 1\n",
        "                # Each reference n-gram can be matched only once (clipping).\n",
        "                label_subs[ngram] -= 1\n",
        "        score *= math.pow(num_matches / num_pred_ngrams, math.pow(0.5, n))\n",
        "    return score"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A6H5EU9Igk7M",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def score(input_seq, label_seq, k):\n",
        "    \"\"\"Translate ``input_seq`` and print its BLEU score against ``label_seq``.\n",
        "\n",
        "    Uses the global ``encoder``, ``decoder`` and ``max_seq_len``.\n",
        "    \"\"\"\n",
        "    pred_tokens = translate(encoder, decoder, input_seq, max_seq_len)\n",
        "    bleu_value = bleu(pred_tokens, label_seq.split(' '), k)\n",
        "    print('bleu %.3f, predict: %s' % (bleu_value, ' '.join(pred_tokens)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h61HqJ7shDgP",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "b427b50e-80a4-46cc-eaa3-5a1063b79ae3"
      },
      "source": [
        "# Evaluate one sentence with BLEU over unigrams and bigrams (k=2).\n",
        "score('ils regardent .', 'they are watching .', k=2)"
      ],
      "execution_count": 41,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "bleu 1.000, predict: they are watching .\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s8BpG5DEhPAD",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "ab1c2a7e-08c1-46a8-e183-f972cd6ca13c"
      },
      "source": [
        "# A second sentence; in the saved run one word is wrong (russian instead of canadian), giving BLEU 0.658.\n",
        "score('ils sont canadiens .', 'they are canadian .', k=2)"
      ],
      "execution_count": 42,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "bleu 0.658, predict: they are russian .\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}